{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the training and test CSV files into DataFrames; both paths are\n",
    "# relative to the notebook's working directory.\n",
    "train_data, test_data = (\n",
    "    pd.read_csv(csv_path) for csv_path in ('train.csv', 'test.csv')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, Dropout, Flatten\n",
    "from tensorflow.keras import Model\n",
    "\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pixel0</th>\n",
       "      <th>pixel1</th>\n",
       "      <th>pixel2</th>\n",
       "      <th>pixel3</th>\n",
       "      <th>pixel4</th>\n",
       "      <th>pixel5</th>\n",
       "      <th>pixel6</th>\n",
       "      <th>pixel7</th>\n",
       "      <th>pixel8</th>\n",
       "      <th>pixel9</th>\n",
       "      <th>...</th>\n",
       "      <th>pixel774</th>\n",
       "      <th>pixel775</th>\n",
       "      <th>pixel776</th>\n",
       "      <th>pixel777</th>\n",
       "      <th>pixel778</th>\n",
       "      <th>pixel779</th>\n",
       "      <th>pixel780</th>\n",
       "      <th>pixel781</th>\n",
       "      <th>pixel782</th>\n",
       "      <th>pixel783</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 784 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   pixel0  pixel1  pixel2  pixel3  pixel4  pixel5  pixel6  pixel7  pixel8  \\\n",
       "0       0       0       0       0       0       0       0       0       0   \n",
       "1       0       0       0       0       0       0       0       0       0   \n",
       "2       0       0       0       0       0       0       0       0       0   \n",
       "3       0       0       0       0       0       0       0       0       0   \n",
       "4       0       0       0       0       0       0       0       0       0   \n",
       "\n",
       "   pixel9  ...  pixel774  pixel775  pixel776  pixel777  pixel778  pixel779  \\\n",
       "0       0  ...         0         0         0         0         0         0   \n",
       "1       0  ...         0         0         0         0         0         0   \n",
       "2       0  ...         0         0         0         0         0         0   \n",
       "3       0  ...         0         0         0         0         0         0   \n",
       "4       0  ...         0         0         0         0         0         0   \n",
       "\n",
       "   pixel780  pixel781  pixel782  pixel783  \n",
       "0         0         0         0         0  \n",
       "1         0         0         0         0  \n",
       "2         0         0         0         0  \n",
       "3         0         0         0         0  \n",
       "4         0         0         0         0  \n",
       "\n",
       "[5 rows x 784 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>pixel0</th>\n",
       "      <th>pixel1</th>\n",
       "      <th>pixel2</th>\n",
       "      <th>pixel3</th>\n",
       "      <th>pixel4</th>\n",
       "      <th>pixel5</th>\n",
       "      <th>pixel6</th>\n",
       "      <th>pixel7</th>\n",
       "      <th>pixel8</th>\n",
       "      <th>...</th>\n",
       "      <th>pixel774</th>\n",
       "      <th>pixel775</th>\n",
       "      <th>pixel776</th>\n",
       "      <th>pixel777</th>\n",
       "      <th>pixel778</th>\n",
       "      <th>pixel779</th>\n",
       "      <th>pixel780</th>\n",
       "      <th>pixel781</th>\n",
       "      <th>pixel782</th>\n",
       "      <th>pixel783</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 785 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   label  pixel0  pixel1  pixel2  pixel3  pixel4  pixel5  pixel6  pixel7  \\\n",
       "0      1       0       0       0       0       0       0       0       0   \n",
       "1      0       0       0       0       0       0       0       0       0   \n",
       "2      1       0       0       0       0       0       0       0       0   \n",
       "3      4       0       0       0       0       0       0       0       0   \n",
       "4      0       0       0       0       0       0       0       0       0   \n",
       "\n",
       "   pixel8  ...  pixel774  pixel775  pixel776  pixel777  pixel778  pixel779  \\\n",
       "0       0  ...         0         0         0         0         0         0   \n",
       "1       0  ...         0         0         0         0         0         0   \n",
       "2       0  ...         0         0         0         0         0         0   \n",
       "3       0  ...         0         0         0         0         0         0   \n",
       "4       0  ...         0         0         0         0         0         0   \n",
       "\n",
       "   pixel780  pixel781  pixel782  pixel783  \n",
       "0         0         0         0         0  \n",
       "1         0         0         0         0  \n",
       "2         0         0         0         0  \n",
       "3         0         0         0         0  \n",
       "4         0         0         0         0  \n",
       "\n",
       "[5 rows x 785 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide every pixel column by 255 (presumably 8-bit intensities, which maps\n",
    "# them into [0, 1]) and one-hot encode the 'label' column.\n",
    "X = train_data.drop('label', axis=1) / 255\n",
    "Y = to_categorical(train_data['label'])\n",
    "\n",
    "# Turn each 784-value row into a 28x28 single-channel image.\n",
    "# * -1 lets NumPy infer the number of rows instead of hard-coding 42000 / 28000,\n",
    "#   so the cell keeps working on a subsample or a different-sized CSV.\n",
    "# * float32 matches the Keras default dtype, which avoids the float64 -> float32\n",
    "#   autocast warning on the first model call and halves the memory footprint.\n",
    "X = np.array(X, dtype=np.float32).reshape(-1, 28, 28, 1)\n",
    "X_predit = np.array(test_data / 255, dtype=np.float32).reshape(-1, 28, 28, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]]\n",
      "\n",
      "\n",
      " [[[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]]\n",
      "\n",
      "\n",
      " [[[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]]\n",
      "\n",
      "\n",
      " ...\n",
      "\n",
      "\n",
      " [[[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]]\n",
      "\n",
      "\n",
      " [[[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]]\n",
      "\n",
      "\n",
      " [[[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]\n",
      "\n",
      "  [[0.]\n",
      "   [0.]\n",
      "   [0.]\n",
      "   ...\n",
      "   [0.]\n",
      "   [0.]\n",
      "   [0.]]]]\n"
     ]
    }
   ],
   "source": [
    "print(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 20% of the labelled images for evaluation; the fixed random_state\n",
    "# keeps the split reproducible between runs.\n",
    "X_train, X_test, Y_train, Y_test = train_test_split(\n",
    "    X,\n",
    "    Y,\n",
    "    test_size=0.2,\n",
    "    random_state=101,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training pipeline: (image, one-hot label) pairs, shuffled through a\n",
    "# 1000-element buffer and grouped into batches of 32.\n",
    "train_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((X_train, Y_train))\n",
    "    .shuffle(1000)\n",
    "    .batch(32)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation pipeline: same batching as training, but no shuffling.\n",
    "test_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((X_test, Y_test))\n",
    "    .batch(32)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(Model):\n",
    "    \"\"\"Convolutional classifier built with Keras model subclassing.\n",
    "\n",
    "    Layout: conv -> pool -> conv -> pool -> conv -> flatten -> dense(120)\n",
    "    -> dense(84) -> dense(10, softmax). The first convolution declares a\n",
    "    (28, 28, 1) input shape.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # Feature extractor: three 3x3 convolutions with 128 filters each,\n",
    "        # with a 2x2 / stride-2 max-pool after the first two.\n",
    "        self.conv1 = Conv2D(\n",
    "            filters=128,\n",
    "            kernel_size=(3, 3),\n",
    "            activation='relu',\n",
    "            input_shape=(28, 28, 1),\n",
    "        )\n",
    "        self.pool1 = MaxPool2D(pool_size=2, strides=2)\n",
    "        self.conv2 = Conv2D(filters=128, kernel_size=(3, 3), activation='relu')\n",
    "        self.pool2 = MaxPool2D(pool_size=2, strides=2)\n",
    "        self.conv3 = Conv2D(filters=128, kernel_size=(3, 3), activation='relu')\n",
    "\n",
    "        # Classifier head: flatten, two hidden dense layers, 10-way softmax.\n",
    "        self.flatten = Flatten()\n",
    "        self.d1 = Dense(units=120, activation='relu')\n",
    "        self.d2 = Dense(units=84, activation='relu')\n",
    "        self.d3 = Dense(units=10, activation='softmax')\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Pass `x` through every layer in order and return the output of `d3`.\"\"\"\n",
    "        stages = (\n",
    "            self.conv1,\n",
    "            self.pool1,\n",
    "            self.conv2,\n",
    "            self.pool2,\n",
    "            self.conv3,\n",
    "            self.flatten,\n",
    "            self.d1,\n",
    "            self.d2,\n",
    "            self.d3,\n",
    "        )\n",
    "        for stage in stages:\n",
    "            x = stage(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the network; train_step and test_step use this global `model`.\n",
    "model = MyModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-entropy over one-hot targets. from_logits=False is the default, written\n",
    "# out here because the model's last layer already applies softmax.\n",
    "loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam with its default step size, written out so it is easy to tune.\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running aggregates, updated once per batch by train_step / test_step and\n",
    "# reset at the end of every epoch by the training loop.\n",
    "\n",
    "# Mean of the per-batch loss values.\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "\n",
    "# Accuracy against one-hot encoded labels.\n",
    "train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')\n",
    "test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(images, labels):\n",
    "    \"\"\"Run one optimisation step on a single batch and update the training metrics.\n",
    "\n",
    "    Args:\n",
    "        images: Batch of inputs fed to the global `model`.\n",
    "        labels: Targets for the batch; compared with the predictions by\n",
    "            `loss_object` and `train_accuracy`.\n",
    "    \"\"\"\n",
    "    # Only the forward pass and the loss have to be recorded on the tape.\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = model(images)\n",
    "        loss = loss_object(labels, predictions)\n",
    "\n",
    "    # Differentiate and apply the update after leaving the tape context so the\n",
    "    # backward-pass and optimizer ops are not themselves recorded.\n",
    "    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "    train_loss(loss)\n",
    "    train_accuracy(labels, predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def test_step(images, labels):\n",
    "    \"\"\"Score one held-out batch and fold the result into the test metrics.\n",
    "\n",
    "    No gradients are computed and no weights are updated here.\n",
    "\n",
    "    Args:\n",
    "        images: Batch of inputs fed to the global `model`.\n",
    "        labels: Targets for the batch.\n",
    "    \"\"\"\n",
    "    outputs = model(images)\n",
    "    batch_loss = loss_object(labels, outputs)\n",
    "\n",
    "    test_loss(batch_loss)\n",
    "    test_accuracy(labels, outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Layer my_model is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n",
      "Epoch 1, Loss: 0.17766238749027252, Accuracy: 94.44642639160156, Test Loss: 0.059819288551807404, Test Accuracy: 98.19047546386719\n",
      "Epoch 2, Loss: 0.05355887860059738, Accuracy: 98.32440948486328, Test Loss: 0.05738315358757973, Test Accuracy: 98.28571319580078\n",
      "Epoch 3, Loss: 0.037625495344400406, Accuracy: 98.81845092773438, Test Loss: 0.0744447112083435, Test Accuracy: 97.88095092773438\n",
      "Epoch 4, Loss: 0.02999773435294628, Accuracy: 99.02083587646484, Test Loss: 0.048548366874456406, Test Accuracy: 98.41667175292969\n",
      "Epoch 5, Loss: 0.024043792858719826, Accuracy: 99.29166412353516, Test Loss: 0.044836994260549545, Test Accuracy: 98.55952453613281\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    # One pass over the shuffled training batches; weights are updated here.\n",
    "    for images, labels in train_ds:\n",
    "        train_step(images, labels)\n",
    "\n",
    "    # Score the held-out split with the weights reached in this epoch.\n",
    "    for test_images, test_labels in test_ds:\n",
    "        test_step(test_images, test_labels)\n",
    "\n",
    "    # Accuracies are reported as percentages.\n",
    "    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "    epoch_summary = template.format(\n",
    "        epoch + 1,\n",
    "        train_loss.result(),\n",
    "        train_accuracy.result() * 100,\n",
    "        test_loss.result(),\n",
    "        test_accuracy.result() * 100,\n",
    "    )\n",
    "    print(epoch_summary)\n",
    "\n",
    "    # Clear the running aggregates so the next epoch reports only its own numbers.\n",
    "    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):\n",
    "        metric.reset_states()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
